"""
词法分析器文件
"""

from tokens import *
from error import *


class Lexer:
    """Lexer: turns source text into a flat list of tokens."""

    def __init__(self, fn, text):
        """
        :param fn: where the text comes from (e.g. a file name); passed to Position
        :param text: the source text to tokenize
        """
        self.fn = fn  # source of the text
        self.text = text  # the text itself
        # Start one step before the first character; advance() below moves onto index 0.
        self.pos = Position(-1, 0, -1, fn, text)
        self.current_char = None  # character at self.pos, or None once past the end
        self.advance()

    def make_tokens(self) -> "tuple[list[Token], Error | None]":
        """
        Tokenize the whole text.

        Spaces and tabs are skipped, '#' comments are dropped, and ';' or a
        newline both produce a TT_NEWLINE token.

        :return: ``(tokens, None)`` on success, where ``tokens`` ends with a
            TT_EOF token; ``([], error)`` on the first lexing error.
        """
        # Characters that always form a one-character token on their own.
        # Built per call so no project constant is needed at class-definition time.
        single_char_types = {
            '[': TT_LSQUARE,
            ']': TT_RSQUARE,
            '^': TT_POW,
            '+': TT_PLUS,
            '*': TT_MUL,
            '/': TT_DIV,
            '(': TT_LPAREN,
            ')': TT_RPAREN,
            ',': TT_COMMA,
            ';': TT_NEWLINE,  # ';' separates statements just like a line break
            '\n': TT_NEWLINE,
        }
        # Characters that start a token needing lookahead; each helper consumes
        # the whole token and returns it.
        token_makers = {
            '=': self.make_equals,  # = or ==
            '<': self.make_less_than,  # < or <=
            '>': self.make_great_than,  # > or >=
            '-': self.make_minus_or_arrow,  # - or ->
            '"': self.make_string,  # string literal
        }

        tokens = []
        while self.current_char is not None:
            char = self.current_char
            if char in (' ', '\t'):
                # Insignificant whitespace.
                self.advance()
            elif char == '#':
                # Comment: skipped up to (not including) the end of the line.
                self.make_comment()
            elif char in DIGITS:
                token, err = self.make_number()
                if err:
                    # make_number stopped on a second '.'; report that character.
                    pos_start = self.pos.copy()
                    self.advance()
                    return [], Error(pos_start, self.pos, '解析数字类型失败', "不是小数")
                tokens.append(token)
            elif char in LETTERS:
                # Identifier or keyword.
                tokens.append(self.make_identifier())
            elif char in token_makers:
                tokens.append(token_makers[char]())
            elif char == '!':
                # Only valid as part of '!='.
                token, err = self.make_not_equals()
                if err:
                    return [], err
                tokens.append(token)
            elif char in single_char_types:
                tokens.append(Token(single_char_types[char], pos_start=self.pos))
                self.advance()
            else:
                pos_start = self.pos.copy()
                self.advance()
                return [], IllegalCharError(pos_start, self.pos, f"'{char}'")
        tokens.append(Token(TT_EOF, pos_start=self.pos))
        return tokens, None

    def advance(self):
        """
        Move to the next character.

        ``current_char`` becomes None once the end of the text is passed.
        """
        # The character being left is handed to Position (presumably so it can
        # track line breaks -- see Position.advance).
        self.pos.advance(self.current_char)
        if self.pos.idx < len(self.text):
            self.current_char = self.text[self.pos.idx]
        else:
            self.current_char = None

    def make_number(self):
        """
        Read an integer or float literal starting at the current digit.

        A leading '-' is not handled here; it is lexed as a separate TT_MINUS token.

        :return: ``(Token, None)`` holding a TT_INT or TT_FLOAT token, or
            ``(None, True)`` when a second '.' is met. In that case the lexer
            is left positioned on the offending '.'.
        """
        allowed = DIGITS + '.'
        pos_start = self.pos.copy()
        num_str = ''
        dot_count = 0  # number of '.' seen so far
        while self.current_char is not None and self.current_char in allowed:
            if self.current_char == '.':
                if dot_count == 1:
                    # Invalid: a number cannot contain two decimal points.
                    return None, True
                dot_count += 1
            num_str += self.current_char
            self.advance()
        if dot_count == 0:
            return Token(TT_INT, int(num_str), pos_start, self.pos), None
        return Token(TT_FLOAT, float(num_str), pos_start, self.pos), None

    def make_identifier(self):
        """
        Read a word made of letters, digits and '_'.

        :return: a TT_KEYWORDS token if the word is in KEYWORDS, otherwise a
            TT_IDENTIFIER token; the token value is the word itself.
        """
        allowed = LETTERS_DIGITS + '_'
        pos_start = self.pos.copy()
        word = ''
        while self.current_char is not None and self.current_char in allowed:
            word += self.current_char
            self.advance()
        token_type = TT_KEYWORDS if word in KEYWORDS else TT_IDENTIFIER
        return Token(token_type, word, pos_start, self.pos)

    def _make_optional_pair(self, single_type, second_char, pair_type):
        """
        Read a one-character operator that may be followed by ``second_char``.

        :return: a ``pair_type`` token spanning both characters when the next
            character is ``second_char``, otherwise a ``single_type`` token.
        """
        pos_start = self.pos.copy()
        self.advance()
        token_type = single_type
        if self.current_char == second_char:
            self.advance()
            token_type = pair_type
        return Token(token_type, pos_start=pos_start, pos_end=self.pos)

    def make_equals(self):
        """Match ``==`` (TT_EE) or the assignment ``=`` (TT_EQ)."""
        return self._make_optional_pair(TT_EQ, '=', TT_EE)

    def make_less_than(self):
        """Match ``<=`` (TT_LTE) or ``<`` (TT_LT)."""
        return self._make_optional_pair(TT_LT, '=', TT_LTE)

    def make_great_than(self):
        """Match ``>=`` (TT_GTE) or ``>`` (TT_GT)."""
        return self._make_optional_pair(TT_GT, '=', TT_GTE)

    def make_not_equals(self):
        """
        Match ``!=``.

        :return: ``(TT_NE token, None)``, or ``(None, ExpectedCharError)`` when
            '!' is not followed by '='.
        """
        pos_start = self.pos.copy()
        self.advance()
        if self.current_char == '=':
            self.advance()
            return Token(TT_NE, pos_start=pos_start, pos_end=self.pos), None
        # Widen the reported span to cover the unexpected character as well.
        self.advance()
        return None, ExpectedCharError(pos_start, self.pos, "在 ! 后, 应该为 = 号")

    def make_minus_or_arrow(self):
        """Match ``->`` (TT_ARROW) or ``-`` (TT_MINUS)."""
        return self._make_optional_pair(TT_MINUS, '>', TT_ARROW)

    def make_string(self):
        """
        Read a double-quoted string literal starting at the opening quote.

        A backslash followed by 'n' or 't' yields a newline or tab. A backslash
        followed by any other character yields that character literally, which
        allows an embedded quote or backslash.

        :return: a TT_STRING token whose value is the decoded contents.
        """
        escapes = {
            'n': '\n',
            't': '\t'
        }
        pos_start = self.pos.copy()
        chars = []
        escaping = False  # True right after an unconsumed backslash
        self.advance()  # skip the opening quote

        while self.current_char is not None and (escaping or self.current_char != '"'):
            if escaping:
                chars.append(escapes.get(self.current_char, self.current_char))
                escaping = False
            elif self.current_char == '\\':
                escaping = True
            else:
                chars.append(self.current_char)
            self.advance()
        # Skip the closing quote.
        # NOTE(review): an unterminated string is not reported as an error; it
        # simply runs to the end of the text and this advance steps past EOF.
        self.advance()

        return Token(TT_STRING, ''.join(chars), pos_start, self.pos)

    def make_comment(self):
        """
        Skip a '#' comment up to, but not including, the end of its line.

        The terminating newline is left in place so make_tokens still emits a
        TT_NEWLINE for it, exactly as for a line without a comment. The scan
        also stops at the end of the text, so a comment on the last line
        without a trailing newline cannot loop forever.
        """
        self.advance()  # skip '#'
        while self.current_char is not None and self.current_char != '\n':
            self.advance()
